import torch
import torch.nn as nn


class PlasticityModel(nn.Module):
    """Learnable plasticity that clamps the logarithmic principal strains of F.

    The yield threshold is stored as an unconstrained raw parameter and is
    mapped through softplus inside ``forward``.  The threshold actually applied
    is therefore ``softplus(yield_threshold)``, which is always positive and is
    about 0.974 for the default raw value of 0.5.
    """

    def __init__(self, yield_threshold: float = 0.5):
        """
        Register the raw (pre-softplus) yield threshold as a trainable parameter.

        Args:
            yield_threshold (float): initial raw value; the effective clamp on
                the logarithmic principal strains is softplus(yield_threshold).
        """
        super().__init__()
        raw_threshold = torch.tensor(yield_threshold)
        self.yield_threshold = nn.Parameter(raw_threshold)

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Project F back so every log singular value lies in [-t, t].

        Here t = softplus(yield_threshold).  F is decomposed as U diag(s) Vh,
        log(s) is clamped, and the matrix is rebuilt with the same U and Vh.

        Args:
            F (torch.Tensor): deformation gradient tensor (B, 3, 3).

        Returns:
            torch.Tensor: corrected deformation gradient tensor (B, 3, 3).
        """
        u, s, vh = torch.linalg.svd(F)

        # softplus keeps the learned limit strictly positive.
        limit = nn.functional.softplus(self.yield_threshold)

        # Floor singular values before the log so a degenerate F cannot yield -inf.
        strain = s.clamp_min(1e-6).log()
        strain = strain.clamp(min=-limit, max=limit)
        s_new = strain.exp()

        return u @ (torch.diag_embed(s_new) @ vh)


class ElasticityModel(nn.Module):
    """Learnable fixed-corotated elasticity returning the Kirchhoff stress.

    Young's modulus is parameterized through its log, and Poisson's ratio
    through a sigmoid scaled to (0, 0.49), so both stay in a valid range while
    training.
    """

    def __init__(self, youngs_modulus_log: float = 10.18, poissons_ratio_sigmoid: float = -0.5):
        """
        Define trainable continuous physical parameters for Corotated Elasticity.

        Args:
            youngs_modulus_log (float): log of Young's modulus (E = exp(value)).
            poissons_ratio_sigmoid (float): pre-sigmoid parameter for Poisson's
                ratio (nu = 0.49 * sigmoid(value)).
        """
        super().__init__()
        self.youngs_modulus_log = nn.Parameter(torch.tensor(youngs_modulus_log))
        self.poissons_ratio_sigmoid = nn.Parameter(torch.tensor(poissons_ratio_sigmoid))

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute the Kirchhoff stress from the deformation gradient (corotated model).

        The first Piola-Kirchhoff stress of the corotated model is
        ``P = 2*mu*(F - R) + la*(J - 1)*J*F^{-T}``.  The Kirchhoff stress
        ``tau = P @ F^T`` is therefore

            tau = 2*mu*(F - R) @ F^T + la*J*(J - 1)*I,

        where R = U @ Vh comes from the SVD of F and J is the product of the
        (floored) singular values.

        Args:
            F (torch.Tensor): deformation gradient tensor (B, 3, 3). Any
                leading batch dimensions are supported.

        Returns:
            kirchhoff_stress (torch.Tensor): Kirchhoff stress tensor with the
                same shape as F.
        """
        # Material parameters (0-dim tensors; they broadcast against F).
        youngs_modulus = self.youngs_modulus_log.exp()
        poissons_ratio = self.poissons_ratio_sigmoid.sigmoid() * 0.49  # in (0, 0.49)

        # Lamé parameters
        mu = youngs_modulus / (2.0 * (1.0 + poissons_ratio))
        la = youngs_modulus * poissons_ratio / ((1.0 + poissons_ratio) * (1.0 - 2.0 * poissons_ratio))

        # SVD of deformation gradient
        U, Sigma, Vh = torch.linalg.svd(F)

        # Floor singular values to keep J away from zero.
        Sigma_clamped = torch.clamp_min(Sigma, 1e-6)

        # NOTE(review): R = U @ Vh may be a reflection when det(F) < 0, and J is
        # built from non-negative singular values (i.e. |det F|). Inverted
        # elements are therefore not distinguished; confirm this is intended.
        R = torch.matmul(U, Vh)

        # Volume ratio, shaped (..., 1, 1) to broadcast over the matrix dims.
        J = Sigma_clamped.prod(dim=-1)[..., None, None]

        I = torch.eye(F.shape[-1], device=F.device, dtype=F.dtype)
        Ft = F.transpose(-1, -2)

        # Deviatoric part: (P_corotated) @ F^T
        corotated = 2.0 * mu * torch.matmul(F - R, Ft)

        # Volumetric part: la*(J-1)*J*F^{-T} @ F^T collapses to la*J*(J-1)*I,
        # so it must NOT be multiplied by F^T again.
        volumetric = la * J * (J - 1.0) * I

        kirchhoff_stress = corotated + volumetric
        return kirchhoff_stress
